import re
import io
import pickle
import numpy as np
import pandas as pd
from scipy.spatial.distance import squareform, pdist

from Bio.PDB import PDBParser
import biotite.structure as bs
from biotite.structure.io.pdb import  PDBFile
from biotite.database import rcsb

from openfold.np.residue_constants import restype_3to1, restype_order, atom_order


def get_seqs(data_path):
    """
    Load dimer sequences from a csv or a pickle file.

    For a csv file the first column is used as the index and the columns
    pdb_id, chain_id, seq and len are read. The two chains of a dimer are
    separated by ',' in the seq and len columns; any ',' found in chain_id is
    removed before the key is built.

    Args:
        data_path: str, path ending with 'csv' or 'pickle'

    Returns:
        For csv: dict mapping '<pdb_id>_<chain_id>' to a dict with the keys
        'dimer_seq', 'chain1_seq', 'chain2_seq', 'chain1_len', 'chain2_len'.
        For pickle: the unpickled object, returned unchanged.

    Raises:
        NameError: data_path ends with neither 'csv' nor 'pickle'.
        AssertionError: a chain sequence length disagrees with the len column.
    """
    if data_path.endswith('pickle'):
        with open(data_path, mode='rb') as handle:
            return pickle.load(handle)
    if not data_path.endswith('csv'):
        raise NameError('file should be csv or pickle.')

    table = pd.read_csv(data_path, index_col=0)
    ids = table.pdb_id.tolist()
    chain_labels = [re.sub(',', '', label) for label in table.chain_id.tolist()]
    keys = [pdb + '_' + label for pdb, label in zip(ids, chain_labels)]
    dimer_seqs = table.seq.tolist()
    declared_lens = [
        [int(piece) for piece in entry.split(',')] for entry in table.len.tolist()
    ]

    sequences = {}
    for key, dimer_seq, declared in zip(keys, dimer_seqs, declared_lens):
        first, second = dimer_seq.split(',')
        # the len column must agree with the actual sequence lengths
        assert len(first) == declared[0], "length of chain1 is wrong, please check"
        assert len(second) == declared[1], "length of chain2 is wrong, please check"
        sequences[key] = {
            'dimer_seq': dimer_seq,
            'chain1_seq': first,
            'chain2_seq': second,
            'chain1_len': declared[0],
            'chain2_len': declared[1],
        }
    return sequences


def _read_dimer_names(data_path):
    """Return the first space-separated token of every non-blank line."""
    names = []
    with open(data_path, 'r') as f:
        for line in f:
            name = line.split(' ')[0].strip()
            if name:  # skip blank lines and lines starting with a space
                names.append(name)
    return names


def _chain_sequence(chain):
    """Return the one-letter sequence of a chain, skipping hetero residues."""
    return ''.join(
        restype_3to1.get(res.resname, "X")  # unknown residue names map to "X"
        for res in chain
        if res.id[0] == " "  # a non-blank hetero flag marks HETATM records
    )


def extract_profiling_from_pdb(data_path, pdb_dir):
    """
    Extract profiling info (seq) from pdb files

    Args:
        data_path: path of a text file listing one dimer per line. Only the
            first space-separated token of a line is used; it must look like
            '<p1>_<p2>' where the first four characters of p1 and p2 are the
            same pdb id and the last character of each is a chain id.
        pdb_dir: prefix of the pdb files; a file is looked up at
            pdb_dir + pdbid[1:3] + '/pdb' + pdbid + '.pdb', so pdb_dir has to
            end with a path separator.

    Returns:
        pandas.DataFrame with the columns pdb_id, chain_id, seq, len,
        resolution and release_date, one row per distinct (pdb id, sorted
        chain pair). seq and len hold the two chains joined by ','.
        Dimers whose pdb file cannot be read, or for which one of the two
        chains is absent from the first model, are skipped.

    Raises:
        AssertionError: the two halves of a dimer name have different pdb ids.
        IndexError: a dimer name contains no '_'.
    """
    # Local import: this module has no file-level logging setup in view.
    import logging

    log = logging.getLogger(__name__)
    log.info('load dimer list')
    pdb_list = _read_dimer_names(data_path)

    np_samples = []
    names = set()
    for dimer in pdb_list:
        t = dimer.split('_')
        p1, p2 = t[0], t[1]
        assert p1[:4] == p2[:4]
        pdbid = p1[:4].lower()

        # sort the chainid in alphabet order
        chainid = ''.join(sorted((p1[-1], p2[-1])))
        name = pdbid + '_' + chainid  # to avoid redundant samples
        if name in names:
            continue
        names.add(name)

        # load pdb file
        pdb_path = pdb_dir + pdbid[1:3] + '/pdb' + pdbid + '.pdb'
        try:
            with open(pdb_path) as f:
                pdb_str = f.read()
        except (OSError, UnicodeDecodeError):
            # best effort: a missing or unreadable file just drops the dimer
            log.info('skip %s: cannot read %s', name, pdb_path)
            continue

        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("none", io.StringIO(pdb_str))
        models = list(structure.get_models())
        chains = models[0].child_dict

        # Both chains are required; a sample with a missing chain would carry
        # fewer than two sequences and cannot be consumed downstream.
        if any(chain_id not in chains for chain_id in chainid):
            log.info('skip %s: chain not found in %s', name, pdb_path)
            continue
        dimer_seq = [_chain_sequence(chains[chain_id]) for chain_id in chainid]

        try:
            resolution = structure.header["resolution"]
            release_date = structure.header["release_date"]
        except (KeyError, AttributeError):
            continue
        dimer_len = ','.join(str(len(x)) for x in dimer_seq)
        np_samples.append(
            [pdbid, chainid, ','.join(dimer_seq), dimer_len, resolution, release_date]
        )

    # convert to dataframe
    return pd.DataFrame(
        np_samples,
        columns=['pdb_id', 'chain_id', 'seq', 'len', 'resolution', 'release_date'],
    )


def get_cb_dist_from_pdbfile(pdb_dir, pdb_id, chain_id, lengths):
    """
    Compute c_beta distance matrix between residues(if no c_beta atoms, use c_alpha atoms instead) for a single dimer
    Note that for a homodimer given with a repeated chain id such as "AA", the
    same chain of the first model is simply read twice; no other special
    handling is done here.
    Args:
        pdb_dir: directory of pdb files; the file is read from
            pdb_dir + pdb_id[1:3] + '/pdb' + pdb_id + '.pdb', so pdb_dir has
            to end with a path separator
        pdb_id: str, "3lf4"
        chain_id: str, "AB" (one character per chain)
        lengths: list, [len1, len2]
    Returns:
        Square numpy array of pairwise distances with sum(lengths) rows, or
        None when some residue has neither a CB nor a CA atom.
    Raises:
        AssertionError: the number of residues read for the first chain, or
            for both chains together, disagrees with lengths.
        KeyError: a chain in chain_id is not present in the first model.
    """
    # Local import: this module has no file-level logging setup in view.
    import logging

    log = logging.getLogger(__name__)

    pdb_file = pdb_dir + pdb_id[1:3] + '/pdb' + pdb_id + '.pdb'
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_id, pdb_file)
    model = structure[0]

    Cbeta_list = []
    for i, cid in enumerate(chain_id):
        chain = model[cid]
        for res in chain:
            if res.id[0] != " ":   # hetero residue, not part of the sequence
                continue
            try:
                atom = res['CB']
            except KeyError:
                try:
                    atom = res['CA']  # e.g. glycine has no CB
                except KeyError:
                    # .get: a non-standard residue name must not raise here
                    log.info("{}_{} residue {} has not CA and CB".format(
                        pdb_id, chain_id, restype_3to1.get(res.resname, "X")))
                    return None
            Cbeta_list.append(atom.coord.tolist())  # [x, y, z]
        if i == 0:
            assert len(Cbeta_list) == lengths[0], "chain 1 length error"
    assert len(Cbeta_list) == sum(lengths), 'total length error'
    # compute pairwise distance matrix
    dist_map = squareform(pdist(np.array(Cbeta_list)))
    return dist_map


def get_dist_maps(data_path, pdb_dir):
    """compute inter-residue distance matrix for all dimers

    Args:
        data_path: path of a csv file (first column is the index) with the
            columns pdb_id, chain_id and len; len holds the two chain lengths
            joined by ','.
        pdb_dir: directory of pdb files, passed on to get_cb_dist_from_pdbfile

    Returns:
        dict mapping '<pdb_id>_<chain_id>' to the distance matrix of that
        dimer. Rows whose pdb_id is not 4 characters long, and dimers for
        which no distance matrix could be computed, are left out.
    """
    data = pd.read_csv(data_path, index_col=0)
    dist_maps = {}
    for _, row in data.iterrows():
        pdb_id = row.pdb_id
        # Skip the row explicitly; otherwise dist_map would be unbound on the
        # first row or carry over the previous dimer's matrix.
        if len(pdb_id) != 4:
            continue
        # drop ',' separators so keys match the ones built by get_seqs
        chain_id = row.chain_id.replace(',', '')
        lengths = [int(x) for x in row.len.split(',')]   # dimer sequence length
        dist_map = get_cb_dist_from_pdbfile(pdb_dir, pdb_id, chain_id, lengths)
        if dist_map is None:
            continue
        dist_maps[pdb_id + '_' + chain_id] = dist_map
    return dist_maps


def get_distogram(dist_maps, min_bin=2.3125, max_bin=21.6875, num_bins=64):
    """Turn every distance map into a map of integer bin labels.

    The bin edges are 0, then num_bins - 1 evenly spaced values from min_bin
    to max_bin, then infinity. Labels start from 1: distances in
    [0, min_bin) get label 1 and distances >= max_bin get label num_bins.

    Args:
        dist_maps: dict mapping a name to an array of distances
        min_bin: first inner bin edge
        max_bin: last inner bin edge
        num_bins: number of bins

    Returns:
        dict with the same keys, each value an integer array of bin labels
        with the shape of the corresponding distance map.
    """
    inner_edges = np.linspace(min_bin, max_bin, num_bins - 1)
    edges = np.concatenate(([0], inner_edges, [np.inf]))
    return {
        key: np.digitize(matrix, edges, right=False)   # start from 1
        for key, matrix in dist_maps.items()
    }


